{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f02e8f1c7f8>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import pickle\n",
    "import os\n",
    "from packages.vocab import Vocab\n",
    "from packages.batch import Batch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "from models.copynet_dbg import CopyEncoder, CopyDecoder\n",
    "from models.functions import numpy_to_var, to_np, to_var, visualize, decoder_initial, update_logger\n",
    "import time\n",
    "import sys\n",
    "import math\n",
    "# Seed PyTorch's global random number generator with a fixed value.\n",
    "torch.manual_seed(1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "# NOTE(review): of the values below, the cells visible in this notebook only\n",
    "# read vocab_size, batch_size, epoch and continue_from. The others look like\n",
    "# leftovers from the training configuration -- confirm before removing them.\n",
    "embed_size = 150\n",
    "hidden_size = 300\n",
    "num_layers = 1\n",
    "bin_size = 10\n",
    "num_epochs = 40\n",
    "prev_end=0\n",
    "# number of samples sliced off per minibatch in the evaluation loop\n",
    "batch_size = 50\n",
    "lr = 0.001\n",
    "# passed to Vocab(...) when the vocabulary object is created\n",
    "vocab_size = 5000\n",
    "weight_decay = 0.99\n",
    "use_saved = True # whether to train from a previous model\n",
    "# epoch + continue_from is the checkpoint number used in the model file names\n",
    "continue_from = 14\n",
    "step = 0 # number of steps taken\n",
    "epoch = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get vocabulary\n",
    "# The .npy files store a pickled Python object inside a 0-d object array;\n",
    "# .item() unwraps it. allow_pickle=True is needed on NumPy >= 1.16.3, where\n",
    "# np.load refuses to unpickle object arrays by default.\n",
    "# NOTE: unpickling can execute arbitrary code -- only load files you trust.\n",
    "vocab = Vocab(vocab_size)\n",
    "# w2i is indexed by token string further down (e.g. vocab.w2i['('])\n",
    "vocab.w2i = np.load('js_dataset/w2i.npy', allow_pickle=True).item()\n",
    "# presumably the inverse (index -> word) table -- confirm in packages.vocab\n",
    "vocab.i2w = np.load('js_dataset/i2w.npy', allow_pickle=True).item()\n",
    "vocab.count = len(vocab.w2i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the evaluation file; every line of it is one sample.\n",
    "file_dir = 'js_dataset/var_dataset_3_shorter.txt'\n",
    "with open(file_dir) as f:\n",
    "    lines = list(f)  # raw lines, line endings still attached\n",
    "# same lines with surrounding whitespace removed\n",
    "test = list(map(str.strip, lines))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Batch helper used by the evaluation loop to turn raw lines into index arrays.\n",
    "# (max_in_len / max_out_len / max_oovs are presumably length and OOV caps --\n",
    "# confirm in packages.batch.)\n",
    "batch = Batch(file_list=[],max_in_len=30,max_out_len=30,max_oovs=12)\n",
    "# Floor division keeps the minibatch count an integer under Python 3, where\n",
    "# `/` would yield a float.\n",
    "batch.num_of_minibatch = len(lines) // batch_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Size of the evaluation set and the number of full minibatches it holds\n",
    "num_samples = len(test)\n",
    "num_batches = num_samples // batch_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home1/irteam/users/mjchoi/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:284: SourceChangeWarning: source code of class 'torch.nn.modules.rnn.GRU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
      "  warnings.warn(msg, SourceChangeWarning)\n",
      "/home1/irteam/users/mjchoi/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:284: SourceChangeWarning: source code of class 'models.copynet_dbg.CopyDecoder' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
      "  warnings.warn(msg, SourceChangeWarning)\n",
      "/home1/irteam/users/mjchoi/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py:284: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
      "  warnings.warn(msg, SourceChangeWarning)\n"
     ]
    }
   ],
   "source": [
    "from collections import OrderedDict  # local import: only this cell needs it\n",
    "\n",
    "# Load the pickled encoder/decoder for checkpoint number `epoch + continue_from`.\n",
    "# NOTE: torch.load unpickles whole module objects, which can execute arbitrary\n",
    "# code -- only load checkpoints you trust.\n",
    "version = 'var_source_code2'\n",
    "encoder = torch.load(f='170704/encoder_%s_%s.pckl' % (version,str(epoch+continue_from)))\n",
    "decoder = torch.load(f='170704/decoder_%s_%s.pckl' % (version,str(epoch+continue_from)))\n",
    "\n",
    "# Modules pickled by an older PyTorch release carry no `_forward_pre_hooks`\n",
    "# dict, and calling them then fails with\n",
    "# \"'CopyEncoder' object has no attribute '_forward_pre_hooks'\".\n",
    "# Backfill an empty hook table on every (sub)module so the models are callable.\n",
    "for model in (encoder, decoder):\n",
    "    for module in model.modules():\n",
    "        if '_forward_pre_hooks' not in module.__dict__:\n",
    "            module._forward_pre_hooks = OrderedDict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'CopyEncoder' object has no attribute '_forward_pre_hooks'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-13-324a45b7679c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     43\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     44\u001b[0m     \u001b[0;31m# 1.5. encoded outputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 45\u001b[0;31m     \u001b[0mencoded\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mencoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     46\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     47\u001b[0m     \u001b[0;31m# 1.6.1. get initial input of decoder\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    221\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    222\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 223\u001b[0;31m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_pre_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    224\u001b[0m             \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    225\u001b[0m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m    256\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    257\u001b[0m         raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0;32m--> 258\u001b[0;31m             type(self).__name__, name))\n\u001b[0m\u001b[1;32m    259\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    260\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__setattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'CopyEncoder' object has no attribute '_forward_pre_hooks'"
     ]
    }
   ],
   "source": [
    "samples_read = 0\n",
    "################################# testing ##################################\n",
    "# Runs the loaded encoder/decoder over the hand-picked samples in `test2`,\n",
    "# feeding each step's arg-max prediction back into the decoder, and collects a\n",
    "# readable per-sample report in `print_list`. Prints the exact-match accuracy.\n",
    "\n",
    "# 1. for each epoch\n",
    "# 1.3. initialize entire batch data (no need...\n",
    "batch.init_batch()\n",
    "\n",
    "#. 1.4, for each minibatch\n",
    "# Each sample is '<context> ; var <name>:==:<expected right-hand side>'.\n",
    "test2 = ['var assert = require ( \" assert \" ) ; var util:==:require (  \" util \"  ) ;',\n",
    "         'var object = this . _ object ; var keys:==:this . _ keys ;',\n",
    "         'var pageX = event . pageX ; var pageY:==:event . pageY ;',\n",
    "         \n",
    "         'var dayNamesShort = ( settings ? settings . dayNamesShort : null ) || '+\n",
    "         'this . _ defaults . dayNamesShort ; var dayNames:==:' +\n",
    "         '( settings ? settings . dayNames : null ) || this . _ defaults . dayNames ;',\n",
    "         'var require = arr . require ; var bppp:==:arr . bppp ;',\n",
    "         'var WebChannelDebug = goog . labs . net . webChannel .'+\n",
    "         ' WebChannelDebug ; var Wire:==:goog . labs . net . webChannel . Wire ;',\n",
    "         'var added = e . added ; var removed:==:e . removed ;',\n",
    "         'var readdirSync = require (  \" fs \"  ) . readdirSync ; '+\n",
    "         ' var statSync:==:require ( \" fs \" ) . statSync ;',\n",
    "         'var chunk = entry . chunk ; var encoding:==:entry . encoding ;']\n",
    "# test2 = test\n",
    "\n",
    "correct = 0\n",
    "total = len(test2)\n",
    "print_list = []\n",
    "while(samples_read<len(test2)):\n",
    "\n",
    "    # 1.4.1. reset the batch helper's per-minibatch state\n",
    "    # (TODO confirm what init_minibatch clears)\n",
    "    batch.init_minibatch()\n",
    "\n",
    "    # 1.4.2. obtain batch outputs\n",
    "    # The slice is bounded by test2 itself (the list being iterated). Bounding it\n",
    "    # by another list could yield an empty slice and a loop that never advances.\n",
    "    data = test2[samples_read:min(samples_read+batch_size,len(test2))]\n",
    "    inputs, outputs = batch.process_minibatch(data,vocab)\n",
    "    samples_read+=len(data)\n",
    "\n",
    "    # 1.4.3. inputs and outputs must be unk-ed to put into model w/ limited vocab\n",
    "    unked_inputs = batch.unk_minibatch(inputs,vocab)\n",
    "    unked_outputs = batch.unk_minibatch(outputs,vocab)\n",
    "    x = numpy_to_var(unked_inputs)\n",
    "    y = numpy_to_var(unked_outputs)\n",
    "\n",
    "    # 1.5. encoded outputs\n",
    "    encoded, _ = encoder(x)\n",
    "\n",
    "    # 1.6.1. get initial input of decoder\n",
    "    decoder_in, s, w = decoder_initial(x.size(0))\n",
    "    decoder_in = y[:,0]\n",
    "\n",
    "    # 1.7. for each decoder timestep\n",
    "    for j in range(y.size(1)-1): # for all sequences\n",
    "        \"\"\"\n",
    "        decoder_in (Variable): [b]\n",
    "        encoded (Variable): [b x seq x hid]\n",
    "        input_out (np.array): [b x seq]\n",
    "        s (Variable): [b x hid]\n",
    "        \"\"\"\n",
    "        # 1.7.1.1st state - create [out]\n",
    "        if j==0:\n",
    "            out, s, w = decoder(input_idx=y[:,j], encoded=encoded,\n",
    "                            encoded_idx=inputs, prev_state=s,\n",
    "                            weighted=w, order=j)\n",
    "#             out[2,0,vocab.w2i['codeMirror']]=1\n",
    "        # remaining states - add results to [out]\n",
    "        else:\n",
    "            # unked_decoder_in was produced at the end of the previous iteration\n",
    "            tmp_out, s, w = decoder(input_idx=unked_decoder_in.squeeze(), encoded=encoded,\n",
    "                            encoded_idx=inputs, prev_state=s,\n",
    "                            weighted=w, order=j)\n",
    "            out = torch.cat([out,tmp_out],dim=1)\n",
    "        # for debugging: stop if nan\n",
    "        if math.isnan(w[-1][0][0].data[0]):\n",
    "            print(\"NaN detected!\")\n",
    "            sys.exit()\n",
    "\n",
    "        # 1.8.1. select next input\n",
    "#         decoder_in = y[:,j] # train with ground truth\n",
    "        # NOTE(review): debugging override -- writes 1 into the '(' score of the\n",
    "        # first sample's first step before the arg-max below, which changes the\n",
    "        # reported predictions. Confirm whether it should stay.\n",
    "        if j==0:\n",
    "            out[0,-1,vocab.w2i['(']]=1\n",
    "        decoder_in = out[:,-1,:].max(1)[1] # arg-max of the latest step is the next input\n",
    "        unked_decoder_in = batch.unk_minibatch(decoder_in.cpu().data.numpy(),vocab)\n",
    "        unked_decoder_in = Variable(torch.LongTensor(unked_decoder_in).cuda())\n",
    "    # 1.9.1. our targeted outputs should include OOV indices\n",
    "    target_outputs = numpy_to_var(outputs[:,1:])\n",
    "\n",
    "    # 1.9.2. get packed versions of target and output\n",
    "    # (neither is read again in the cells visible here)\n",
    "    target = pack_padded_sequence(target_outputs,batch.output_lens.tolist(), batch_first=True)[0]\n",
    "    pad_out = pack_padded_sequence(out,batch.output_lens.tolist(), batch_first=True)[0]\n",
    "    # Build the text report for every sample of this minibatch.\n",
    "    for idx in range(len(data)):\n",
    "        input_print = []\n",
    "        truth_print = []\n",
    "        predict_print = []\n",
    "        # input indices up to the first 0 (presumably padding -- TODO confirm)\n",
    "        for i in inputs[idx]:\n",
    "            if i==0:\n",
    "                break\n",
    "            else:\n",
    "                input_print.append(i)\n",
    "        # target indices: stop at 3 and skip 2 (presumably the end-of-sequence\n",
    "        # and start-of-sequence markers -- TODO confirm)\n",
    "        for i in outputs[idx]:\n",
    "            if i==3:\n",
    "                break\n",
    "            elif i==2:\n",
    "                pass\n",
    "            else:\n",
    "                truth_print.append(i)\n",
    "        # per-step arg-max predictions, cut at the first 3\n",
    "        for i in out[idx,:,:].max(1)[1].cpu().data.numpy():\n",
    "            if i==3:\n",
    "                break\n",
    "            else:\n",
    "                predict_print.append(i)\n",
    "        line0 = \"\\n===================================================================\"\n",
    "        line1 = 'Input1:       '+''.join(vocab.idx_list_to_word_list(input_print, batch.idx2oov_list[idx]))\n",
    "        line2 = 'Output:       '+''.join(vocab.idx_list_to_word_list(truth_print, batch.idx2oov_list[idx]))\n",
    "        line3 = 'Predict[UNK]: '+''.join(vocab.idx_list_to_word_list(predict_print))\n",
    "        line4 = 'Predicted:    '+''.join(vocab.idx_list_to_word_list(predict_print, batch.idx2oov_list[idx]))\n",
    "        # tokens were joined without separators; put the space after 'var' back\n",
    "        line1 = line1.replace('var', 'var ')\n",
    "        line1 = line1.replace(';',';\\nInput2:       ')\n",
    "        line2 = line2.replace('var', 'var ')\n",
    "        line3 = line3.replace('var', 'var ')\n",
    "        line4 = line4.replace('var', 'var ')\n",
    "        # 14 = width of the 'Output:       ' / 'Predicted:    ' labels, so this\n",
    "        # compares only the rendered token text\n",
    "        if line2[14:]==line4[14:]:\n",
    "            correct+=1\n",
    "            line4+='\\n***CORRECT***'\n",
    "        print_list.extend([line0,line1,line2,line3,line4])\n",
    "# with open('test_results_%s_epoch_%d_acc_%1.3f.txt' \n",
    "#           %(version,epoch+continue_from,correct*1.0/total),'w') as f:\n",
    "#     f.write('\\n'.join(print_list))\n",
    "# fraction of samples whose rendered prediction equals the rendered target\n",
    "print(correct*1.0/total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['\\n===================================================================',\n",
       " 'Input1:       var dayNamesShort=(settings?settings.dayNamesShort:null)||this._defaults.dayNamesShort;\\nInput2:       var dayNames',\n",
       " 'Output:       (settings?settings.dayNames:null)||this._defaults.dayNames;',\n",
       " 'Predict[UNK]: (settings?settings.dayNames:null)||settings.dayNamesShort;',\n",
       " 'Predicted:    (settings?settings.dayNames:null)||settings.dayNamesShort;',\n",
       " '\\n===================================================================',\n",
       " 'Input1:       var WebChannelDebug=goog.labs.net.webChannel.WebChannelDebug;\\nInput2:       var Wire',\n",
       " 'Output:       goog.labs.net.webChannel.Wire;',\n",
       " 'Predict[UNK]: goog.labs.net.<UNK>.<UNK>;',\n",
       " 'Predicted:    goog.labs.net.Wire.Wire;',\n",
       " '\\n===================================================================',\n",
       " 'Input1:       var readdirSync=require( \"fs\" ).readdirSync;\\nInput2:        var statSync',\n",
       " 'Output:       require(\"fs\").statSync;',\n",
       " 'Predict[UNK]: require(\"path\").statSync;',\n",
       " 'Predicted:    require(\"path\").statSync;',\n",
       " '\\n===================================================================',\n",
       " 'Input1:       var assert=require(\"assert\");\\nInput2:       var util',\n",
       " 'Output:       require( \"util\" );',\n",
       " 'Predict[UNK]: require(\"util\");',\n",
       " 'Predicted:    require(\"util\");',\n",
       " '\\n===================================================================',\n",
       " 'Input1:       var object=this._object;\\nInput2:       var keys',\n",
       " 'Output:       this._keys;',\n",
       " 'Predict[UNK]: this._keys;',\n",
       " 'Predicted:    this._keys;\\n***CORRECT***',\n",
       " '\\n===================================================================',\n",
       " 'Input1:       var chunk=entry.chunk;\\nInput2:       var encoding',\n",
       " 'Output:       entry.encoding;',\n",
       " 'Predict[UNK]: entry.encoding;',\n",
       " 'Predicted:    entry.encoding;\\n***CORRECT***',\n",
       " '\\n===================================================================',\n",
       " 'Input1:       var added=e.added;\\nInput2:       var removed',\n",
       " 'Output:       e.removed;',\n",
       " 'Predict[UNK]: e.removed;',\n",
       " 'Predicted:    e.removed;\\n***CORRECT***',\n",
       " '\\n===================================================================',\n",
       " 'Input1:       var require=arr.require;\\nInput2:       var bppp',\n",
       " 'Output:       arr.bppp;',\n",
       " 'Predict[UNK]: arr.<UNK>;',\n",
       " 'Predicted:    arr.bppp;\\n***CORRECT***',\n",
       " '\\n===================================================================',\n",
       " 'Input1:       var pageX=event.pageX;\\nInput2:       var pageY',\n",
       " 'Output:       event.pageY;',\n",
       " 'Predict[UNK]: event.pageY;',\n",
       " 'Predicted:    event.pageY;\\n***CORRECT***']"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the report lines collected in print_list by the evaluation loop.\n",
    "print_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Debug: slice `out` at batch row 2, every decoding step, and the vocabulary\n",
    "# column of the token 'fs' (`out` is laid out as [b x seq x vocab]).\n",
    "out[2,:,vocab.w2i['fs']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Debug: UNK the target indices of sample `idx` (the loop variable left over\n",
    "# from the evaluation cell) and render them as space-separated words.\n",
    "unked = batch.unk_minibatch(outputs[idx],vocab)\n",
    "' '.join(vocab.idx_list_to_word_list(unked,batch.idx2oov_list[idx]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Debug: show the input of sample `idx` as words, then its UNK-ed target.\n",
    "# Uses `inputs` / `outputs` from the evaluation cell; the earlier names\n",
    "# `stories` / `summaries` are not defined anywhere in this notebook.\n",
    "# The first line is printed because only a cell's last expression is displayed.\n",
    "print(' '.join(vocab.idx_list_to_word_list(inputs[idx],batch.idx2oov_list[idx])))\n",
    "unked = batch.unk_minibatch(outputs[idx],vocab)\n",
    "' '.join(vocab.idx_list_to_word_list(unked,batch.idx2oov_list[idx]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Debug: render the target of sample `idx` as words, passing the sample's\n",
    "# idx2oov_list entry to the lookup. Uses `outputs` from the evaluation cell;\n",
    "# `summaries` is not defined anywhere in this notebook.\n",
    "' '.join(vocab.idx_list_to_word_list(outputs[idx],batch.idx2oov_list[idx]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Debug: render the target of sample `idx` as words without passing an OOV\n",
    "# table. Uses `outputs` from the evaluation cell; `summaries` is not defined\n",
    "# anywhere in this notebook.\n",
    "' '.join(vocab.idx_list_to_word_list(outputs[idx]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Debug: inspect the oov2idx_list entry of sample `idx` (presumably that\n",
    "# sample's OOV word -> index table; confirm in packages.batch).\n",
    "batch.oov2idx_list[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Scratch: allocate a 3 x 5 x 2 tensor (contents are uninitialized); `a` is not\n",
    "# read by any other cell visible in this notebook.\n",
    "a = torch.Tensor(3,5,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Debug: look up the vocabulary index stored for the '<EOS>' token.\n",
    "vocab.w2i['<EOS>']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# out : [b x seq x vocab] -> [b x seq]"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
